Predict whether income exceeds $50K/yr based on census data. Also known as the “Adult” dataset. Extraction was done by Barry Becker from the 1994 Census database. The prediction task is to determine whether a person makes over $50K a year. See the data source and description for more information. These data are also used for demonstrating TensorFlow.
The biggest drivers for predicting income over $50k are: marital status (married is better), education (more is better), and sex (male is better). We will explore the continuous and categorical predictors before building statistical models. Data manipulation is carried out in dplyr and visualizations are done in ggplot2 and plotly.
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
library(tidyverse)
library(plotly)
library(pROC)
library(glmnet)
The data can be downloaded from the web. The training and test data are 3.8 MB and 1.9 MB respectively. The missing values are converted from ? to NA.
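As a minimal sketch of the ?-to-NA conversion at read time, the raw files could be read with base R roughly as follows. The column names are an assumption inferred from the names used later in this document, `read_adult` is a hypothetical helper, and the two sample rows are made up to mimic the raw layout (comma plus space separators, `?` for missing values).

```r
# Column names are an assumption based on the names used later in this document
col_names <- c("age", "workclass", "fnlwgt", "education", "education_num",
               "marital_status", "occupation", "relationship", "race", "gender",
               "capital_gain", "capital_loss", "hours_per_week",
               "native_country", "income_bracket")

# Hypothetical helper: read a headerless file and turn "?" into NA
read_adult <- function(file, skip = 0) {
  read.csv(file, header = FALSE, col.names = col_names,
           na.strings = "?",       # "?" marks a missing value
           strip.white = TRUE,     # fields are padded with a leading space
           skip = skip, stringsAsFactors = FALSE)
}

# Two made-up rows in the raw layout, used here instead of the real files
sample_lines <- c(
  "40, Private, 100000, Bachelors, 13, Never-married, Sales, Not-in-family, White, Female, 0, 0, 40, United-States, <=50K",
  "54, ?, 180000, Some-college, 10, Married-civ-spouse, ?, Husband, White, Male, 0, 0, 60, ?, >50K"
)
raw <- read_adult(textConnection(sample_lines))
colSums(is.na(raw))   # count of missing values per column
```

For the downloaded files this would presumably be called as `read_adult("data/train_raw.csv")`. The test file may need `skip = 1` if it begins with a non-data line, which is worth checking by eye before relying on it.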
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
"data/train_raw.csv")
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
"data/test_raw.csv")
Convert the target variable income_bracket into a numeric value. Create a new column age_buckets and remove records with missing values. Apply the same steps to both the training and test data. Create interactions if desired. Note: interactions can be extremely time-consuming to model, so they are examined here but not included in the predictive models.
format.rawdata <- function(data){
data %>%
mutate(label = ifelse(income_bracket == ">50K" | income_bracket == ">50K.", 1, 0)) %>%
mutate(age_buckets = cut(age, c(16, 18, 25, 30, 35, 40, 45, 50, 55, 60, 65, 90))) %>%
select(label, gender, native_country, education, education_num, occupation, workclass, marital_status,
race, age_buckets) %>%
na.omit
}
train <- train_raw %>% format.rawdata
test <- test_raw %>% format.rawdata
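To make the bucketing step concrete, here is a small sketch of how `cut` assigns ages to the intervals used in `format.rawdata`. The example ages are made up. Intervals are open on the left and closed on the right, so an age equal to the lowest break falls outside every bucket and becomes NA, which `na.omit` would then drop.

```r
# Same breaks as in format.rawdata
breaks <- c(16, 18, 25, 30, 35, 40, 45, 50, 55, 60, 65, 90)

# Made-up ages chosen to sit on and around the interval boundaries
ages <- c(17, 18, 19, 25, 26, 90, 16)
buckets <- cut(ages, breaks)

# Show each age next to the bucket it lands in
data.frame(age = ages, age_bucket = buckets)
```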
Most of the columns in the census data are categorical. We plot a few of the most important columns here. The categorical columns kept for modeling are gender, native_country, education, occupation, workclass, marital_status, and race.
plot.main.effects <- function(data, x, y){
data %>%
select(group = all_of(x), metric = all_of(y)) %>% # column names passed as strings
group_by(group) %>%
summarize(percent = 100 * mean(metric)) %>%
ggplot(aes(x = reorder(group, percent), percent)) +
geom_bar(stat="identity", fill = "lightblue4") +
coord_flip() +
labs(y = "Percent", x = "") +
ggtitle(paste("Percent surveyed with incomes over $50k by", x))
}
plot.main.effects(train, "marital_status", "label")
plot.main.effects(train, "gender", "label")
plot.main.effects(train, "education", "label")
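The quantity behind each bar is simply the mean of the 0/1 label within a group, scaled to a percentage. A base R sketch of that calculation on made-up data (the `toy` data frame and the `percent_by_group` helper are hypothetical, not part of the original analysis):

```r
# Made-up data: three men (two above $50k) and two women (none above $50k)
toy <- data.frame(
  gender = c("Male", "Male", "Female", "Female", "Male"),
  label  = c(1, 0, 0, 0, 1)
)

# Hypothetical helper: percent of rows with label == 1 in each group, sorted ascending
percent_by_group <- function(data, x, y) {
  sort(100 * tapply(data[[y]], data[[x]], mean))
}

pct <- percent_by_group(toy, "gender", "label")
pct
```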
We can compare the distributions of the continuous variables for those who earn more than $50k and those who earn less. The continuous variables plotted here are age, education_num, and hours_per_week.
plot.continuous <- function(data, x, y, alpha = 0.2, ...){
lab <- stringr::str_replace_all(y, "_", " ") %>% stringr::str_to_title() # e.g. "Hours Per Week"
data %>%
select(groups = all_of(x), value = all_of(y)) %>% # column names passed as strings
na.omit %>%
ggplot(aes(value, fill = groups)) + geom_density(alpha = alpha, ...) +
labs(x = lab, y = "") +
ggtitle(paste0("Income by ", lab))
}
# People who earn more also work more, are better educated, and are older
plot.continuous(train_raw, "income_bracket", "age")
plot.continuous(train_raw, "income_bracket", "education_num", adjust = 5)
plot.continuous(train_raw, "income_bracket", "hours_per_week", adjust = 5)
We can examine some two-way and three-way interactions with heat maps:
p <- train %>%
select(education_num, age_buckets, label) %>%
group_by(age_buckets, education_num) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(education_num, age_buckets, fill = percent)) +
geom_tile() +
labs(x = "Education", y = "Age") +
ggtitle("Percent surveyed with incomes over $50k by age, education")
ggplotly(p)
p <- train %>%
select(age_buckets, education_num, occupation, label) %>%
group_by(age_buckets, education_num, occupation) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(education_num, age_buckets, fill = percent)) +
geom_tile() +
facet_wrap( ~ occupation) +
labs(x = "Education", y = "Age") +
ggtitle("Percent surveyed with incomes over $50k by age, education, and occupation")
ggplotly(p)
write_csv(train, "data/train.csv")
write_csv(test, "data/test.csv")
The logistic model uses main effects only and is fit to the training data with no regularization. We assess the model fit with a hold-out sample. Logistic models can be built with glm from the stats package.
Gender, education, and marital status are all highly significant. Marital status in particular is a good predictor of earning more than $50k.
m1 <- glm(label ~ gender + native_country + education + occupation + workclass + marital_status +
race + age_buckets, binomial, train)
summary(m1)
Call:
glm(formula = label ~ gender + native_country + education + occupation +
workclass + marital_status + race + age_buckets, family = binomial,
data = train)
Deviance Residuals:
Min 1Q Median 3Q Max
-2.6001 -0.5627 -0.2295 -0.0001 3.8200
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -16.15783 81.90324 -0.197 0.84361
genderMale 0.30583 0.05027 6.083 1.18e-09 ***
native_countryCanada -0.67372 0.67788 -0.994 0.32029
native_countryChina -1.78499 0.69178 -2.580 0.00987 **
native_countryColumbia -3.14179 1.00637 -3.122 0.00180 **
native_countryCuba -0.78317 0.68906 -1.137 0.25572
native_countryDominican-Republic -2.13266 0.98598 -2.163 0.03054 *
native_countryEcuador -1.22454 0.93018 -1.316 0.18802
native_countryEl-Salvador -1.39951 0.76844 -1.821 0.06857 .
native_countryEngland -0.53931 0.69006 -0.782 0.43449
native_countryFrance -0.38991 0.81306 -0.480 0.63154
native_countryGermany -0.57296 0.66658 -0.860 0.39003
native_countryGreece -1.70925 0.81647 -2.093 0.03631 *
native_countryGuatemala -1.13839 0.92321 -1.233 0.21755
native_countryHaiti -1.31481 0.88943 -1.478 0.13934
native_countryHoland-Netherlands -13.77494 2399.54480 -0.006 0.99542
native_countryHonduras -1.77411 1.75548 -1.011 0.31220
native_countryHong -1.21231 0.87454 -1.386 0.16568
native_countryHungary -0.88362 0.97130 -0.910 0.36297
native_countryIndia -1.53554 0.66047 -2.325 0.02008 *
native_countryIran -0.93288 0.73550 -1.268 0.20467
native_countryIreland -0.39039 0.87096 -0.448 0.65399
native_countryItaly -0.28367 0.69837 -0.406 0.68460
native_countryJamaica -1.14196 0.74807 -1.527 0.12688
native_countryJapan -0.84284 0.71013 -1.187 0.23527
native_countryLaos -1.67837 1.08622 -1.545 0.12231
native_countryMexico -1.49385 0.65548 -2.279 0.02267 *
native_countryNicaragua -1.64888 1.01535 -1.624 0.10439
native_countryOutlying-US(Guam-USVI-etc) -15.24852 579.46692 -0.026 0.97901
native_countryPeru -2.00894 1.01398 -1.981 0.04756 *
native_countryPhilippines -0.84343 0.63640 -1.325 0.18506
native_countryPoland -1.12283 0.73874 -1.520 0.12853
native_countryPortugal -1.08854 0.89048 -1.222 0.22155
native_countryPuerto-Rico -1.36009 0.72420 -1.878 0.06037 .
native_countryScotland -1.52519 1.11662 -1.366 0.17197
native_countrySouth -1.97837 0.70600 -2.802 0.00508 **
native_countryTaiwan -1.43724 0.73488 -1.956 0.05049 .
native_countryThailand -1.47170 0.99100 -1.485 0.13753
native_countryTrinadad&Tobago -1.47980 1.01391 -1.460 0.14443
native_countryUnited-States -0.80574 0.62400 -1.291 0.19661
native_countryVietnam -2.01444 0.82255 -2.449 0.01432 *
native_countryYugoslavia -0.27549 0.89965 -0.306 0.75944
education11th 0.10758 0.20589 0.523 0.60131
education12th 0.51047 0.26358 1.937 0.05278 .
education1st-4th -0.57722 0.47004 -1.228 0.21943
education5th-6th -0.41266 0.34977 -1.180 0.23809
education7th-8th -0.45579 0.23369 -1.950 0.05113 .
education9th -0.31743 0.26050 -1.219 0.22302
educationAssoc-acdm 1.20443 0.17195 7.004 2.48e-12 ***
educationAssoc-voc 1.20440 0.16503 7.298 2.92e-13 ***
educationBachelors 1.90172 0.15370 12.373 < 2e-16 ***
educationDoctorate 3.10123 0.20970 14.789 < 2e-16 ***
educationHS-grad 0.73165 0.14945 4.896 9.80e-07 ***
educationMasters 2.28057 0.16371 13.930 < 2e-16 ***
educationPreschool -13.08227 301.13781 -0.043 0.96535
educationProf-school 3.05504 0.19480 15.683 < 2e-16 ***
educationSome-college 1.06187 0.15168 7.001 2.55e-12 ***
occupationArmed-Forces -0.87568 1.41321 -0.620 0.53550
occupationCraft-repair -0.03640 0.07594 -0.479 0.63172
occupationExec-managerial 0.86970 0.07205 12.072 < 2e-16 ***
occupationFarming-fishing -0.74541 0.12892 -5.782 7.38e-09 ***
occupationHandlers-cleaners -0.76078 0.13892 -5.477 4.34e-08 ***
occupationMachine-op-inspct -0.38296 0.09814 -3.902 9.53e-05 ***
occupationOther-service -0.91502 0.11293 -8.102 5.39e-16 ***
occupationPriv-house-serv -2.16427 1.02326 -2.115 0.03442 *
occupationProf-specialty 0.53786 0.07667 7.015 2.30e-12 ***
occupationProtective-serv 0.65652 0.12144 5.406 6.43e-08 ***
occupationSales 0.36815 0.07736 4.759 1.95e-06 ***
occupationTech-support 0.55767 0.10697 5.213 1.86e-07 ***
occupationTransport-moving -0.08087 0.09426 -0.858 0.39092
workclassLocal-gov -0.64511 0.10729 -6.013 1.82e-09 ***
workclassPrivate -0.35831 0.08925 -4.015 5.95e-05 ***
workclassSelf-emp-inc 0.04159 0.11751 0.354 0.72340
workclassSelf-emp-not-inc -0.74070 0.10435 -7.098 1.26e-12 ***
workclassState-gov -0.84546 0.11962 -7.068 1.57e-12 ***
workclassWithout-pay -14.96782 542.24463 -0.028 0.97798
marital_statusMarried-AF-spouse 3.24802 0.48912 6.640 3.13e-11 ***
marital_statusMarried-civ-spouse 2.10784 0.06209 33.949 < 2e-16 ***
marital_statusMarried-spouse-absent 0.09028 0.21530 0.419 0.67500
marital_statusNever-married -0.18746 0.07735 -2.424 0.01537 *
marital_statusSeparated -0.10303 0.14811 -0.696 0.48666
marital_statusWidowed 0.33062 0.13960 2.368 0.01787 *
raceAsian-Pac-Islander 0.68870 0.26582 2.591 0.00957 **
raceBlack 0.41753 0.22289 1.873 0.06103 .
raceOther 0.02531 0.35842 0.071 0.94369
raceWhite 0.56547 0.21275 2.658 0.00786 **
age_buckets(18,25] 11.13552 81.90046 0.136 0.89185
age_buckets(25,30] 12.21290 81.90042 0.149 0.88146
age_buckets(30,35] 12.66809 81.90042 0.155 0.87708
age_buckets(35,40] 13.10329 81.90042 0.160 0.87289
age_buckets(40,45] 13.16779 81.90042 0.161 0.87227
age_buckets(45,50] 13.33564 81.90042 0.163 0.87065
age_buckets(50,55] 13.36368 81.90043 0.163 0.87038
age_buckets(55,60] 13.19114 81.90043 0.161 0.87204
age_buckets(60,65] 12.76302 81.90046 0.156 0.87616
age_buckets(65,90] 12.39737 81.90047 0.151 0.87968
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 33851 on 30161 degrees of freedom
Residual deviance: 21510 on 30066 degrees of freedom
AIC: 21702
Number of Fisher Scoring iterations: 15
#anova(m1) # takes a while to run
#plot(m1) # legacy plots not that useful
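The coefficients in the summary are on the log-odds scale, so exponentiating them gives odds ratios relative to each factor's reference level. A small sketch using three estimates copied by hand from the output above (with the fitted model available, `exp(coef(m1))` would give the full set):

```r
# Log-odds estimates copied from the summary(m1) output above
log_odds <- c(
  "marital_statusMarried-civ-spouse" = 2.10784,
  "educationBachelors"               = 1.90172,
  "genderMale"                       = 0.30583
)

# Odds ratio = exp(log-odds); values above 1 raise the odds of earning over $50k
odds_ratios <- exp(log_odds)
round(odds_ratios, 2)
```

Read this way, being married to a civilian spouse multiplies the odds of earning over $50k by roughly 8 compared with the reference marital status, holding the other predictors fixed.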
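The area under the curve (AUC) of 0.883 is measured on the hold-out test set, so it reflects out-of-sample discrimination rather than overfitting; overfitting would instead show up as a gap between the train and test bars in the lift chart. The lift chart shows that 80% of those in the upper decile earn more than $50k, compared to a tiny fraction in the lower decile.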
# Predict
pred <- bind_rows("train" = train, "test" = test, .id = "data") %>%
mutate(pred = predict(m1, ., type = "response")) %>%
mutate(decile = ntile(desc(pred), 10)) %>%
select(data, label, pred, decile)
# ROC plot
pred %>%
filter(data == "test") %>%
roc(label ~ pred, .) %>%
plot.roc(., print.auc = TRUE)
# Lift plot
pred %>%
group_by(data, decile) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(decile, percent, fill = data)) + geom_bar(stat = "Identity", position = "dodge") +
ggtitle("Lift chart for logistic regression model")
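For intuition about what the AUC measures, it equals the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below computes it from ranks in base R on made-up labels and scores. `auc_rank` is a hypothetical helper for illustration, not the method pROC uses internally.

```r
# Hypothetical helper: AUC from the rank-sum (Mann-Whitney) identity
auc_rank <- function(label, pred) {
  r  <- rank(pred)            # average ranks, so tied scores count as half
  n1 <- sum(label == 1)       # number of positives
  n0 <- sum(label == 0)       # number of negatives
  (sum(r[label == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Made-up labels and predicted probabilities
toy_label <- c(0, 0, 1, 0, 1, 1)
toy_pred  <- c(0.10, 0.30, 0.35, 0.40, 0.70, 0.90)

toy_auc <- auc_rank(toy_label, toy_pred)
toy_auc
```

In this toy example 8 of the 9 positive/negative pairs are ordered correctly, so the result is 8/9.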
The elastic net is a regularized regression method that uses L1 (lasso) and L2 (ridge) penalties. We can build elastic net models with the glmnet package.
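As a rough illustration of how the two penalties combine, glmnet adds a term of the form lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(abs(beta))) to the loss. The sketch below evaluates that term for a made-up coefficient vector; `elastic_net_penalty` is a hypothetical helper written for this example, not a glmnet function.

```r
# Hypothetical helper: elastic net penalty in the glmnet parameterization
elastic_net_penalty <- function(beta, lambda, alpha = 1) {
  ridge <- sum(beta^2)        # squared L2 norm
  lasso <- sum(abs(beta))     # L1 norm
  lambda * ((1 - alpha) / 2 * ridge + alpha * lasso)
}

beta <- c(1, -2, 0)           # made-up coefficients

pen_lasso <- elastic_net_penalty(beta, lambda = 0.1, alpha = 1)    # pure lasso
pen_ridge <- elastic_net_penalty(beta, lambda = 0.1, alpha = 0)    # pure ridge
pen_mixed <- elastic_net_penalty(beta, lambda = 0.1, alpha = 0.5)  # equal mix
c(lasso = pen_lasso, ridge = pen_ridge, mixed = pen_mixed)
```

Setting alpha between 0 and 1 blends the two behaviours: the L1 part can shrink coefficients exactly to zero, while the L2 part shrinks them smoothly.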
Whereas the logistic method used formulas, glmnet requires us to construct a model matrix from the categorical predictors. We then choose a value of lambda, which controls the overall strength of the penalty; the mix between the L1 and L2 penalties is set separately by alpha, and the glmnet default of alpha = 1 used below is the pure lasso. We can examine the predictor sets selected at different values of lambda. Optionally, we can use cross-validation to choose lambda programmatically.
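To see what the model matrix looks like, here is a toy sketch with made-up rows. Each factor is expanded into 0/1 indicator columns named by pasting the variable name and the level, with the first level dropped as the reference. This naming is why the train/test indicator ends up in a column called `datatrain` and the response in `label1`, and why the first three columns are excluded from the predictors.

```r
# Made-up rows with the same leading columns as the combined train/test data
toy <- data.frame(
  data   = factor(c("train", "train", "test", "test")),
  label  = factor(c(0, 1, 0, 1)),
  gender = factor(c("Male", "Female", "Female", "Male"))
)

# Expand factors into indicator columns
mm <- model.matrix(~ ., toy)
colnames(mm)

# Split back into training predictors and response, as done for the real data
is_train <- mm[, "datatrain"] == 1
toy_x <- mm[is_train, -(1:3), drop = FALSE]   # drop intercept, datatrain, label1
toy_y <- mm[is_train, "label1"]
```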
# Convert to factors
alldata <- bind_rows("train" = train, "test" = test, .id = "data") %>%
select(-education_num) %>%
mutate(across(everything(), factor)) %>% # replaces the deprecated mutate_each()/funs()
model.matrix( ~ ., .)
# Create training prediction matrix
train.factors <- list(x = alldata[alldata[,'datatrain'] == 1, -(1:3)],
y = alldata[alldata[,'datatrain'] == 1, 3])
# Create test prediction matrix
test.factors <- list(x = alldata[alldata[,'datatrain'] == 0, -(1:3)],
y = alldata[alldata[,'datatrain'] == 0, 3])
# Fit a regularized model
fit1 <- glmnet(train.factors$x, train.factors$y, family = "binomial")
plot(fit1)
print(fit1)
Call: glmnet(x = train.factors$x, y = train.factors$y, family = "binomial")
Df %Dev Lambda
[1,] 0 -2.170e-13 0.1926000
[2,] 1 2.996e-02 0.1755000
[3,] 1 5.485e-02 0.1599000
[4,] 1 7.566e-02 0.1457000
[5,] 1 9.312e-02 0.1327000
[6,] 1 1.078e-01 0.1210000
[7,] 1 1.202e-01 0.1102000
[8,] 1 1.308e-01 0.1004000
[9,] 1 1.396e-01 0.0915000
[10,] 1 1.472e-01 0.0833700
[11,] 2 1.564e-01 0.0759600
[12,] 3 1.710e-01 0.0692100
[13,] 5 1.852e-01 0.0630700
[14,] 5 1.989e-01 0.0574600
[15,] 7 2.121e-01 0.0523600
[16,] 7 2.246e-01 0.0477100
[17,] 8 2.356e-01 0.0434700
[18,] 8 2.461e-01 0.0396100
[19,] 8 2.551e-01 0.0360900
[20,] 10 2.629e-01 0.0328800
[21,] 11 2.713e-01 0.0299600
[22,] 13 2.788e-01 0.0273000
[23,] 14 2.855e-01 0.0248700
[24,] 14 2.913e-01 0.0226600
[25,] 17 2.970e-01 0.0206500
[26,] 17 3.023e-01 0.0188200
[27,] 23 3.074e-01 0.0171400
[28,] 26 3.127e-01 0.0156200
[29,] 27 3.177e-01 0.0142300
[30,] 28 3.221e-01 0.0129700
[31,] 30 3.261e-01 0.0118200
[32,] 34 3.300e-01 0.0107700
[33,] 35 3.337e-01 0.0098110
[34,] 35 3.369e-01 0.0089390
[35,] 37 3.396e-01 0.0081450
[36,] 38 3.421e-01 0.0074220
[37,] 39 3.443e-01 0.0067620
[38,] 40 3.463e-01 0.0061620
[39,] 40 3.479e-01 0.0056140
[40,] 42 3.496e-01 0.0051150
[41,] 43 3.512e-01 0.0046610
[42,] 44 3.526e-01 0.0042470
[43,] 47 3.538e-01 0.0038700
[44,] 50 3.548e-01 0.0035260
[45,] 51 3.557e-01 0.0032130
[46,] 54 3.565e-01 0.0029270
[47,] 57 3.572e-01 0.0026670
[48,] 60 3.578e-01 0.0024300
[49,] 63 3.583e-01 0.0022140
[50,] 65 3.588e-01 0.0020180
[51,] 66 3.592e-01 0.0018380
[52,] 67 3.597e-01 0.0016750
[53,] 71 3.601e-01 0.0015260
[54,] 73 3.604e-01 0.0013910
[55,] 74 3.607e-01 0.0012670
[56,] 74 3.610e-01 0.0011550
[57,] 77 3.612e-01 0.0010520
[58,] 78 3.615e-01 0.0009585
[59,] 80 3.618e-01 0.0008734
[60,] 81 3.621e-01 0.0007958
[61,] 81 3.623e-01 0.0007251
[62,] 81 3.624e-01 0.0006607
[63,] 82 3.625e-01 0.0006020
[64,] 82 3.627e-01 0.0005485
[65,] 84 3.628e-01 0.0004998
[66,] 88 3.628e-01 0.0004554
[67,] 90 3.629e-01 0.0004149
[68,] 91 3.631e-01 0.0003781
[69,] 91 3.633e-01 0.0003445
[70,] 92 3.634e-01 0.0003139
[71,] 91 3.634e-01 0.0002860
[72,] 91 3.635e-01 0.0002606
[73,] 92 3.635e-01 0.0002374
[74,] 92 3.636e-01 0.0002163
[75,] 92 3.636e-01 0.0001971
[76,] 90 3.638e-01 0.0001796
[77,] 90 3.638e-01 0.0001637
(m2 <- coef(fit1, s = 0.02)) # extract coefficients at a single value of lambda
96 x 1 sparse Matrix of class "dgCMatrix"
1
(Intercept) -2.48987926
genderMale .
native_countryCanada .
native_countryChina .
native_countryColumbia .
native_countryCuba .
native_countryDominican-Republic .
native_countryEcuador .
native_countryEl-Salvador .
native_countryEngland .
native_countryFrance .
native_countryGermany .
native_countryGreece .
native_countryGuatemala .
native_countryHaiti .
native_countryHoland-Netherlands .
native_countryHonduras .
native_countryHong .
native_countryHungary .
native_countryIndia .
native_countryIran .
native_countryIreland .
native_countryItaly .
native_countryJamaica .
native_countryJapan .
native_countryLaos .
native_countryMexico .
native_countryNicaragua .
native_countryOutlying-US(Guam-USVI-etc) .
native_countryPeru .
native_countryPhilippines .
native_countryPoland .
native_countryPortugal .
native_countryPuerto-Rico .
native_countryScotland .
native_countrySouth .
native_countryTaiwan .
native_countryThailand .
native_countryTrinadad&Tobago .
native_countryUnited-States .
native_countryVietnam .
native_countryYugoslavia .
education11th .
education12th .
education1st-4th .
education5th-6th .
education7th-8th -0.10172508
education9th .
educationAssoc-acdm .
educationAssoc-voc .
educationBachelors 0.69417178
educationDoctorate 1.04226523
educationHS-grad -0.03860506
educationMasters 0.89614956
educationPreschool .
educationProf-school 1.19493744
educationSome-college .
occupationArmed-Forces .
occupationCraft-repair .
occupationExec-managerial 0.69611385
occupationFarming-fishing -0.06129799
occupationHandlers-cleaners .
occupationMachine-op-inspct .
occupationOther-service -0.28702345
occupationPriv-house-serv .
occupationProf-specialty 0.43502928
occupationProtective-serv .
occupationSales .
occupationTech-support .
occupationTransport-moving .
workclassLocal-gov .
workclassPrivate .
workclassSelf-emp-inc 0.20665622
workclassSelf-emp-not-inc .
workclassState-gov .
workclassWithout-pay .
marital_statusMarried-AF-spouse .
marital_statusMarried-civ-spouse 1.85107323
marital_statusMarried-spouse-absent .
marital_statusNever-married -0.07063281
marital_statusSeparated .
marital_statusWidowed .
raceAsian-Pac-Islander .
raceBlack .
raceOther .
raceWhite .
age_buckets(18,25] -0.72088908
age_buckets(25,30] -0.24146074
age_buckets(30,35] .
age_buckets(35,40] .
age_buckets(40,45] .
age_buckets(45,50] 0.08753949
age_buckets(50,55] 0.01396170
age_buckets(55,60] .
age_buckets(60,65] .
age_buckets(65,90] .
# Cross validation (long running for full dataset)
cvfit <- cv.glmnet(train.factors$x, train.factors$y, family = "binomial", type.measure = "class")
plot(cvfit)
cvfit$lambda.min # 0.0001971255
Once you have chosen a value for lambda, you can score the test set and examine the ROC and lift charts. This model has a slightly lower AUC and lift than logistic regression, but the overall results look very similar.
# Predict and plot the AUC
test.factors$pred <- predict(fit1, test.factors$x, s=0.02, type = "response") # make predictions
data.frame(resp = test.factors$y, pred = c(test.factors$pred)) %>%
roc(resp ~ pred, .) %>%
plot.roc(., print.auc = TRUE)
# Lift chart
data.frame(data = ifelse(alldata[, 'datatrain'], "train", "test"),
label = alldata[,'label1'],
pred = c(predict(fit1, alldata[, -(1:3)], s = 0.02, type = "response"))) %>% # probabilities, as in the ROC step
mutate(decile = ntile(desc(pred), 10)) %>%
group_by(data, decile) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(decile, percent, fill = data)) + geom_bar(stat = "Identity", position = "dodge") +
ggtitle("Lift chart for elastic net model")
Finally, save the predicted output and the model for building apps.
# Score predictions
pred.out <- test %>%
mutate(pred.glm = pred$pred[pred$data == "test"]) %>%
mutate(pred.net = c(test.factors$pred)) %>%
mutate(income_bracket = ifelse(label, ">50K", "<=50K")
)
# Output predictions to file
write_csv(pred.out, "data/pred.csv")
saveRDS(m1, file = "data/logisticModel.rds")
saveRDS(m2, file = "data/elasticnetModel.rds")
If you want to try other models, take a look at the caret package. The caret package (short for _C_lassification _A_nd _RE_gression _T_raining) is a set of functions that attempt to streamline the process for creating predictive models. The package contains tools for data splitting, pre-processing, feature selection, model tuning using resampling, and variable importance estimation, as well as other functionality. See the caret documentation for more details.
library(caret)
library(e1071)
library(gbm)
## convert label to factor
train$y <- factor(train$label)
## Cross validation
fitControl <- trainControl(method = "cv", number = 3) # repeats only applies to method = "repeatedcv"
## Fit a gbm model with cross validation (this will take a long time!)
gbmFit1 <- train(y ~ gender + education + occupation + workclass + marital_status + age_buckets,
data = train,
method = "gbm",
trControl = fitControl,
verbose = FALSE)
## Summarize
summary(gbmFit1)